#!/usr/bin/env python3
# Copyright (c) 2021 maminjie <canpool@163.com>
# SPDX-License-Identifier: MulanPSL-2.0

"""
Query the NVD CVE REST API (1.0) and print CVE summaries as CSV.

Reference: https://nvd.nist.gov/General/News/New-NVD-CVE-CPE-API-and-SOAP-Retirement
"""

import json
import urllib
import urllib.request
import urllib.parse
import urllib.error


class Nvd(object):
    """
    Minimal client for the NVD CVE REST API (JSON, 1.0 endpoints).

    Every request path handed to the methods below is relative to
    ``self.params["url"]``. Network and HTTP failures are reported on stdout
    and turned into a ``None`` return value instead of being raised.
    """
    def __init__(self):
        self.params = {
            "url": "https://services.nvd.nist.gov/rest/json"
            }
        self.headers = {
            'User-Agent':'Mozilla/5.0 (Windows NT 10.0; WOW 64; rv:50.0) Gecko/20100101 Firefox/50.0'
            }

    def http_request(self, url, data=None, headers=None, method=None):
        """
        Send a request to ``self.params["url"] + url`` and return the body.

        :param url: path (optionally with query string) relative to the base URL
        :param data: optional request body (bytes)
        :param headers: optional dict of request headers (None means no headers)
        :param method: HTTP method name such as "GET"
        :return: the response body decoded as UTF-8, or None on an HTTP/URL error
        """
        url = self.params["url"] + url
        # a None default avoids sharing one mutable dict across calls
        req = urllib.request.Request(url=url, data=data, headers=headers or {}, method=method)
        try:
            # the with-block closes the response even if read/decode fails
            with urllib.request.urlopen(req) as result:
                return result.read().decode("utf-8")
        except urllib.error.HTTPError as e:
            # HTTPError subclasses URLError, so it must be handled first
            print("{} ERROR: ".format(method) + str(url).split("?")[0])
            print("{} ERROR: ".format(method) + str(e.code) + " " + str(e.reason))
            return None
        except urllib.error.URLError as e:
            print("{} ERROR: ".format(method) + str(url).split("?")[0])
            # e.reason is frequently an exception object (e.g. OSError), not a str
            print("{} ERROR: ".format(method) + str(e.reason))
            return None

    def get_request(self, url, headers=None):
        """GET ``url``; when headers is None a copy of the default headers is sent."""
        if headers is None:
            headers = self.headers.copy()
        return self.http_request(url=url, headers=headers, method="GET")

    def get_json(self, url):
        """
        GET ``url`` and decode the body as JSON.

        :return: the decoded JSON value, or the falsy response (None or "")
                 unchanged when the request failed or the body was empty
        """
        headers = self.headers.copy()
        headers["Content-Type"] = "application/json;charset=UTF-8"
        resp = self.get_request(url, headers)
        if resp:
            return json.loads(resp)
        return resp

    def get(self, url, param=""):
        """GET the JSON document at ``url + param``."""
        return self.get_json(url + param)

    def _parse_query_param(self, **kwargs):
        """
        Build a URL query string (without leading "?") from keyword arguments.

        Values are percent-encoded (space -> %20) so that e.g. a keyword with
        spaces still forms a valid URL; keys keep the order they were passed in.
        """
        return urllib.parse.urlencode(kwargs, quote_via=urllib.parse.quote)

    def get_one_cve(self, cveId):
        """Fetch the JSON record of a single CVE id from /cve/1.0/<id>."""
        url_template = "/cve/1.0/{cveId}"
        # quote the id so it can never alter the request path
        url = url_template.format(cveId=urllib.parse.quote(str(cveId), safe=""))
        return self.get(url)

    def get_cves(self, **kwargs):
        """
        Query /cves/1.0 with the given keyword arguments as query parameters.

        :return: the decoded JSON response, or None when no arguments are given
                 or the request failed
        """
        if kwargs:
            param = self._parse_query_param(**kwargs)
            url = "/cves/1.0?{query_param}".format(query_param=param)
            return self.get(url)
        return None


import os
import argparse
import re
import time


class NvdStat(Nvd):
    """
    Statistics: print NVD CVE records as CSV rows on stdout.

    Rows have the columns ``id,desc,cvssV3,cvssV2,date``; get_one_cve() and
    get_cves() print that header line before any data rows.
    """
    def __init__(self):
        super(NvdStat, self).__init__()

    def _convert_time_with_zone(self, time_str):
        """
        Reduce a timestamp such as "2021-03-04T10:15Z" to its date "2021-03-04".

        :raises ValueError: if time_str does not match "%Y-%m-%dT%H:%MZ"
        """
        time_array = time.strptime(time_str, "%Y-%m-%dT%H:%MZ")
        return time.strftime("%Y-%m-%d", time_array)

    def _csv_quote_string(self, data):
        """Wrap data in double quotes, doubling embedded double quotes (CSV escaping)."""
        return '"{}"'.format(data.replace('"', '""'))

    def _get_list(self, arg):
        """
        convert the arg to list, the arg is a string or file
        the content of string or file is separated by spaces

        eg:
            arg = " a b  c d  "
            get_list(arg)    => ['a','b','c','d']
        """
        lst = []
        if os.path.isfile(arg):
            with open(arg, 'r') as f:
                for ln in f:
                    lst.extend(ln.split())
        else:
            lst = list(arg.split())
        return lst

    @staticmethod
    def _base_score(impact, metric_key, cvss_key):
        """
        Return impact[metric_key][cvss_key]["baseScore"].

        Returns "" (blank CSV column) when the metric or CVSS section is
        missing or empty.
        """
        metric = impact.get(metric_key)
        if not metric:
            return ""
        cvss = metric.get(cvss_key)
        if not cvss:
            return ""
        return cvss.get('baseScore')

    def _print_cve(self, resp):
        """
        Print one CSV row per item in resp["result"]["CVE_Items"].

        Missing optional sections (impact, description, publishedDate) produce
        blank columns instead of an exception.

        :param resp: decoded JSON dict from a /cve or /cves request
        :return: resp["totalResults"] (falsy when nothing matched)
        """
        total_results = resp.get('totalResults')
        if not total_results:
            return total_results
        result = resp.get('result') or {}
        for cve_item in result.get('CVE_Items') or []:
            cve = cve_item.get('cve') or {}
            cve_id = (cve.get('CVE_data_meta') or {}).get('ID')
            descriptions = (cve.get("description") or {}).get("description_data") or []
            description = (descriptions[0].get("value") or "") if descriptions else ""
            description = self._csv_quote_string(description)
            impact = cve_item.get('impact') or {}
            v3_score = self._base_score(impact, 'baseMetricV3', 'cvssV3')
            v2_score = self._base_score(impact, 'baseMetricV2', 'cvssV2')
            published = cve_item.get('publishedDate')
            published_date = self._convert_time_with_zone(published) if published else ""
            print("{},{},{},{},{}".format(cve_id, description, v3_score, v2_score, published_date))
        return total_results

    def get_one_cve(self, cveId):
        """
        Print a CSV row for each CVE id in cveId.

        :param cveId: whitespace-separated CVE ids, or the path of a file holding them
        """
        print("id,desc,cvssV3,cvssV2,date")
        for cve in self._get_list(cveId):
            resp = super(NvdStat, self).get_one_cve(cve)
            if resp:
                self._print_cve(resp)

    def get_cves(self, keyword):
        """
        Print every CVE exactly matching keyword, fetching it page by page.

        Pagination advances startIndex by the number of items actually
        received and stops once totalResults items have been seen, when a
        page comes back empty, or when a request fails.
        """
        print("id,desc,cvssV3,cvssV2,date")
        index = 0
        perpage = 100
        while True:
            resp = super(NvdStat, self).get_cves(keyword=keyword, startIndex=index,
                                                 resultsPerPage=perpage, isExactMatch="true")
            if not resp:
                break
            total_results = self._print_cve(resp)
            if not total_results:
                break
            # totalResults is the size of the whole result set, not of this page
            page_size = len((resp.get('result') or {}).get('CVE_Items') or [])
            if page_size == 0:
                # an empty page would otherwise never advance index
                break
            index += page_size
            if index >= total_results:
                break


def parse_command_line():
    """
    Parse the command line of the tool.

    :return: argparse.Namespace with ``keyword`` and ``cve`` (each a str or None)
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        (("-k", "--keyword"), "The keyword of soft package"),
        (("-c", "--cve"), "The cve file or list"),
    )
    for flags, help_text in option_specs:
        parser.add_argument(*flags, type=str, help=help_text)
    return parser.parse_args()


def main():
    """
    CLI entry point: run a keyword search, otherwise a CVE-id lookup.

    The keyword option takes precedence when both are given; with neither,
    nothing is queried.
    """
    args = parse_command_line()
    client = NvdStat()
    dispatch = (
        (args.keyword, client.get_cves),
        (args.cve, client.get_one_cve),
    )
    for value, action in dispatch:
        if value:
            action(value)
            break


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
